[mm] vision attention backend for XPU #1584
Signed-off-by: AlpinDale <alpindale@gmail.com>
Code Review
The pull request introduces XPU backend support for vision attention, specifically for Flash Attention. This involves modifying function signatures to make out optional and adding dropout_p, and implementing conditional logic to use ipex_ops.varlen_attention when block_table is None. Additionally, it updates the maybe_get_vit_flash_attn_backend and __init__ methods to correctly handle XPU platform specifics and adjusts import paths for flash_attn_varlen_func. The changes also include minor refactoring for attn_backend checks in qwen2_5_vl.py and qwen2_vl.py for improved readability. Overall, the changes seem to correctly integrate XPU support and maintain compatibility with existing functionalities.
| if current_platform.is_xpu(): | ||
| self.use_upstream_fa = False |
Setting self.use_upstream_fa = False on XPU forces the custom aphrodite Flash Attention implementation to be used instead of the upstream flash_attn library. This matters for both compatibility and performance on XPU, so the custom implementation should be tested on XPU hardware to avoid regressions.
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| cu_seqlens_q: torch.Tensor, | ||
| seqused_k: torch.Tensor, # we don't support this in ipex kernel | ||
| max_seqlen_q: int, | ||
| max_seqlen_k: int, | ||
| softmax_scale: float, | ||
| causal: bool, | ||
| block_table: torch.Tensor, | ||
| alibi_slopes: torch.Tensor | None, | ||
| softmax_scale: float | None = None, | ||
| causal: bool = False, | ||
| out: torch.Tensor | None = None, | ||
| block_table: torch.Tensor | None = None, | ||
| alibi_slopes: torch.Tensor | None = None, | ||
| window_size: list[int] | None = None, | ||
| softcap: float | None = 0.0, | ||
| seqused_k: torch.Tensor | None = None, | ||
| cu_seqlens_k: torch.Tensor | None = None, | ||
| # passed in qwen vl | ||
| dropout_p: float = 0.0, |
The new flash_attn_varlen_func signature is harder to read than it needs to be. Required parameters already come first, but the optional out parameter now sits in the middle of the optional ones, between causal and block_table. Moving out to the end of the list, after dropout_p, would keep the attention options grouped together and improve readability.
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| cu_seqlens_q: torch.Tensor, | |
| max_seqlen_q: int, | |
| max_seqlen_k: int, | |
| softmax_scale: float | None = None, | |
| causal: bool = False, | |
| block_table: torch.Tensor | None = None, | |
| alibi_slopes: torch.Tensor | None = None, | |
| window_size: list[int] | None = None, | |
| softcap: float | None = 0.0, | |
| seqused_k: torch.Tensor | None = None, | |
| cu_seqlens_k: torch.Tensor | None = None, | |
| dropout_p: float = 0.0, | |
| out: torch.Tensor | None = None, |
| if out is None: | ||
| out = torch.empty(q.shape, dtype=q.dtype, device=q.device) |
Allocating out with torch.empty when the caller passes None is a reasonable approach. It does assume that the output always matches q in shape, dtype, and device. If that holds for every caller, this is fine; otherwise the allocation needs to handle the cases where the expected output shape or device differs from q.
| q.contiguous(), | ||
| k.contiguous(), | ||
| v.contiguous(), | ||
| out, | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| None, | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| 0.0, | ||
| softmax_scale, | ||
| False, | ||
| causal, | ||
| False, | ||
| None, | ||
| real_window_size[0], | ||
| real_window_size[1], | ||
| -1, | ||
| ) |
The ipex_ops.varlen_attention call passes several hardcoded positional values: 0.0 for dropout, a False on either side of causal, two None arguments, and a trailing -1. Their meaning is not evident at the call site, and the wrapper now accepts dropout_p but does not appear to forward it. These may be the right values for this use case, but it would be clearer to pass them as named arguments, or to derive them from the existing parameters where they are configurable. That makes the call easier to audit and more flexible for future changes.
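The named-argument idea can be sketched in isolation. This is a minimal illustration under stated assumptions, not the PR's code: positional_kernel is a hypothetical stand-in whose parameter names (flag_a, flag_b) are invented and do not reflect the real IPEX signature, and attention is a hypothetical wrapper name.

```python
def positional_kernel(q, k, v, dropout, scale, flag_a, causal, flag_b):
    """Hypothetical stand-in kernel; this is NOT the real IPEX signature."""
    # Echo the options back so the call can be inspected.
    return {
        "dropout": dropout,
        "scale": scale,
        "flag_a": flag_a,
        "causal": causal,
        "flag_b": flag_b,
    }


def attention(q, k, v, *, scale, causal=False, dropout_p=0.0):
    """Keyword-only wrapper: every option is named at the call site."""
    return positional_kernel(
        q,
        k,
        v,
        dropout_p,  # forwarded instead of a hardcoded 0.0
        scale,
        False,  # flag_a: fixed value, documented once here
        causal,
        False,  # flag_b: fixed value, documented once here
    )
```

Callers then write attention(q, k, v, scale=0.125, causal=True), and the literals that remain are annotated in exactly one place.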
| elif current_platform.is_xpu(): | ||
| assert attn_backend == _Backend.FLASH_ATTN, "XPU platform only supports FLASH_ATTN as vision attention backend." | ||
| use_upstream_fa = False |
The assert raises an AssertionError if attn_backend is not _Backend.FLASH_ATTN on XPU, and it is skipped entirely when Python runs with the -O flag. It would be more user-friendly to raise an explicit exception such as ValueError with a descriptive message, or to fall back to a supported backend where possible instead of crashing. Which option is right depends on the project's error-handling strategy.
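One way to act on this suggestion is sketched below. It is a minimal, self-contained illustration rather than the PR's code: Backend is a stand-in enum for _Backend, and check_xpu_vit_backend with its is_xpu flag is a hypothetical helper standing in for the current_platform.is_xpu() branch.

```python
from enum import Enum, auto


class Backend(Enum):
    """Stand-in for the project's _Backend enum (assumption, illustration only)."""

    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()


def check_xpu_vit_backend(attn_backend: Backend, is_xpu: bool) -> Backend:
    """Validate the vision attention backend; hypothetical helper."""
    if is_xpu and attn_backend is not Backend.FLASH_ATTN:
        # An explicit raise is not stripped when Python runs with -O,
        # unlike an assert statement.
        raise ValueError(
            "XPU platform only supports FLASH_ATTN as vision attention "
            f"backend, got {attn_backend.name}."
        )
    return attn_backend
```

The error message names the rejected backend, which makes a misconfiguration easier to diagnose than a bare AssertionError.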
| self.attn_backend = ( | ||
| backend | ||
| if backend | ||
| in { | ||
| _Backend.TORCH_SDPA, | ||
| _Backend.XFORMERS, | ||
| _Backend.PALLAS, | ||
| _Backend.ROCM_AITER_FA, | ||
| _Backend.FLASH_ATTN, | ||
| } | ||
| else _Backend.TORCH_SDPA | ||
| ) |
With the if current_platform.is_xpu(): block removed, self.attn_backend is assigned directly, so XPU no longer explicitly defaults to _Backend.TORCH_SDPA when backend is not FLASH_ATTN. This is consistent with maybe_get_vit_flash_attn_backend, which now handles XPU-specific backend selection. Please confirm that an unsupported backend cannot slip through on XPU, since TORCH_SDPA, XFORMERS, PALLAS, and ROCM_AITER_FA all pass this filter.
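The selection logic under discussion can be exercised on its own with a small sketch. The names here are assumptions for illustration: Backend is a stand-in for _Backend, OTHER is a hypothetical member representing any value outside the supported set, and select_vit_backend is a hypothetical helper mirroring the conditional expression in __init__.

```python
from enum import Enum, auto


class Backend(Enum):
    """Stand-in for the project's _Backend enum (assumption, illustration only)."""

    TORCH_SDPA = auto()
    XFORMERS = auto()
    PALLAS = auto()
    ROCM_AITER_FA = auto()
    FLASH_ATTN = auto()
    OTHER = auto()  # hypothetical member outside the supported set


# Same members as the set literal in the diff above.
SUPPORTED_VIT_BACKENDS = frozenset(
    {
        Backend.TORCH_SDPA,
        Backend.XFORMERS,
        Backend.PALLAS,
        Backend.ROCM_AITER_FA,
        Backend.FLASH_ATTN,
    }
)


def select_vit_backend(backend: Backend) -> Backend:
    """Keep a supported backend, otherwise fall back to TORCH_SDPA."""
    return backend if backend in SUPPORTED_VIT_BACKENDS else Backend.TORCH_SDPA
```

Note that this filter is platform-agnostic: a value such as XFORMERS passes it unchanged, so any XPU-specific restriction has to be enforced elsewhere, as the assert in maybe_get_vit_flash_attn_backend does.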
| def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: | ||
| from aphrodite.attention.backends.registry import _Backend | ||
| return _Backend.FLASH_ATTN |
The get_vit_attn_backend method for XPU unconditionally returns _Backend.FLASH_ATTN, regardless of head_size and dtype. That is clear if Flash Attention is the only vision attention backend intended for XPU. If other backends may be supported later, or the choice should depend on hardware capabilities or user preference, this will need to become configurable.
No description provided.